Code taken from the PyMC3 Dawid-Skene example notebook: https://pymc-devs.github.io/pymc3/notebooks/dawid-skene.html
I collected it here for easy reference and for my own experiments.
The data can be found in this repo: https://github.com/abhishekmalali/questioning-strategy-classification
The model follows the formulation in https://aclweb.org/anthology/W/W13/W13-2323.pdf
In [1]:
%matplotlib inline
import pymc3 as pm
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
In [2]:
# Paths to the annotation tensor and to its reference (gold-standard) labels.
data_file = 'data/extrahard_MC_500_5_4.npz.npy'
truth_file = 'data/extrahard_MC_500_5_4_reference_classes.npy'
data = np.load(data_file)
z_true = np.load(truth_file)
# Axis sizes of `data`: items (I), annotators (J), classes (K).
I, J, K = data.shape[0], data.shape[1], data.shape[2]
N = I * J  # total number of (item, annotator) pairs
In [3]:
# Flatten the one-hot response tensor into parallel lists of
# (item ID, annotator ID, chosen label), and seed each item's latent true
# class with the majority vote of its annotators.
ii = []  # item IDs
jj = []  # annotator IDs
y = []   # label chosen by annotator jj[n] for item ii[n]
z_init = np.zeros(I, dtype=np.int64)
for i, item_responses in enumerate(data):
    votes = []
    for j, response in enumerate(item_responses):
        # Index of the first class marked 1; raises IndexError if none is.
        k = np.flatnonzero(response == 1)[0]
        votes.append(k)
        ii.append(i)
        jj.append(j)
        y.append(k)
    # Most frequent label wins; argmax over the counts breaks ties toward
    # the smallest class index.
    z_init[i] = np.bincount(np.asarray(votes)).argmax()
In [4]:
# Compare the majority-vote initialisation with the reference labels
# (per sklearn's convention, rows index z_true and columns index z_init).
confMat = confusion_matrix(y_true=z_true, y_pred=z_init)
header = "Majority vote estimate of true category:\n"
print(header, confMat)
In [5]:
# Class-prevalence prior: all-ones Dirichlet concentration (flat).
alpha = np.full(K, 1.0)
# Annotator confusion-matrix prior: concentration 2 on the diagonal and 1
# elsewhere, so each annotator's matrix leans toward the correct label.
beta = np.eye(K) + 1.0
In [6]:
# Dawid-Skene model: each item i has a latent true class z[i] drawn from the
# class prevalence pi; each observed label y[n] is drawn from row z[ii[n]] of
# annotator jj[n]'s confusion matrix theta[jj[n]].
with pm.Model() as model:
    # probability that an item belongs to each class k
    pi = pm.Dirichlet('pi', a=alpha, shape=K)
    # one K x K confusion matrix per annotator; the (K, K) beta prior is
    # shared across the J annotators
    theta = pm.Dirichlet('theta', a=beta, shape=(J, K, K))
    # true category of each item, initialised at the majority vote
    z = pm.Categorical('z', p=pi, shape=I, testval=z_init)
    # label distribution for each response: row z[ii[n]] of theta[jj[n]]
    label_probs = theta[jj, z[ii]]
    y_obs = pm.Categorical('y_obs', p=label_probs, observed=y)
In [7]:
# Compound sampler: Metropolis for the Dirichlet parameters (pi, theta) and
# CategoricalGibbsMetropolis for the discrete item classes z.
with model:
    step1 = pm.Metropolis(vars=[pi, theta])
    step2 = pm.CategoricalGibbsMetropolis(vars=[z])
    steps = [step1, step2]
    trace = pm.sample(draws=5000, step=steps, progressbar=True)